import numpy as np
from scipy import stats
from numba import njit
import time_dependent as td
import scipy
import scipy.optimize as opt


def transition_matrix(mu, sig, distribution, tau, x_grid):
    """Build the Markov transition matrix of the state on ``x_grid``.

    With probability ``tau`` the state receives an innovation with location
    ``mu`` and standard deviation ``sig``. The innovation is normal, or
    Laplace with the same standard deviation when
    ``distribution == 'laplace'``. With probability ``1 - tau`` the state
    stays where it is.

    Mass is assigned to grid points by integrating the innovation over cells
    whose edges are the midpoints between neighbouring grid points. The two
    outermost cells extend to -inf / +inf, so every row sums to one.

    Parameters
    ----------
    mu : float
        Location of the innovation.
    sig : float
        Standard deviation of the innovation.
    distribution : str
        ``'laplace'`` selects Laplace innovations; any other value means normal.
    tau : float
        Probability that an innovation arrives.
    x_grid : array_like
        Increasing 1-D grid with at least two points.

    Returns
    -------
    ndarray, shape (nx, nx)
        ``Q[i, j]`` is the probability of moving from ``x_grid[i]`` to
        ``x_grid[j]``.
    """
    x_grid = np.asarray(x_grid, dtype=float)
    nx = len(x_grid)

    # Compare by value: `is` tests object identity and only matched when the
    # caller happened to pass the same interned string object.
    if distribution == 'laplace':
        # A Laplace distribution with scale b has variance 2*b**2, so b = sig/sqrt(2).
        scale = sig / np.sqrt(2)

        def cdf(z):
            return stats.laplace.cdf((z - mu) / scale)
    else:
        def cdf(z):
            return stats.norm.cdf((z - mu) / sig)

    # Cell edges between neighbouring grid points, shape (nx-1,).
    edges = 0.5 * (x_grid[:-1] + x_grid[1:])
    # F[i, k]: probability that an innovation from x_grid[i] lands below edges[k].
    F = cdf(edges[np.newaxis, :] - x_grid[:, np.newaxis])

    Q = np.empty((nx, nx))
    Q[:, 0] = F[:, 0]
    Q[:, 1:nx - 1] = F[:, 1:] - F[:, :-1]
    Q[:, nx - 1] = 1 - F[:, -1]

    # Mix with "no innovation" (stay put) with probability 1 - tau.
    Q = (1 - tau) * np.identity(nx) + tau * Q
    return Q


def parabolic_max(x0, x1, x2, y0, y1, y2):
    """Return the vertex ``(x, y)`` of the parabola through three points.

    Fits ``c2*x**2 + c1*x + c0`` through ``(x0, y0)``, ``(x1, y1)`` and
    ``(x2, y2)`` using the Lagrange form, then returns its stationary point.
    This lets a maximum fall between grid points. The x-values must be
    distinct, and the fitted curvature ``c2`` must be non-zero.
    """
    # Denominators of the three Lagrange basis polynomials.
    d0 = (x0 - x1) * (x0 - x2)
    d1 = (x1 - x0) * (x1 - x2)
    d2 = (x2 - x1) * (x2 - x0)

    quad = y0 / d0 + y1 / d1 + y2 / d2
    lin = -y0 * (x1 + x2) / d0 - y1 * (x0 + x2) / d1 - y2 * (x1 + x0) / d2
    const = y0 * (x1 * x2) / d0 + y1 * (x0 * x2) / d1 + y2 * (x1 * x0) / d2

    vertex_x = -0.5 * lin / quad
    vertex_y = quad * (vertex_x ** 2) + lin * vertex_x + const
    return vertex_x, vertex_y


def backward_iteration(V_p, Q, x_grid, r, mu, tau, distribution, sig, la, mu_k, sig_k, gap, discount, shift, elast, compute_Q=False):
    """One step of backward (value-function) iteration on the state grid.

    Given next period's value ``V_p``, this step computes:

    * today's value of not adjusting, ``V_not``;
    * the optimal reset point ``x_max``, refined between grid points by
      parabolic interpolation;
    * the adjustment probability ``q`` at every grid point;
    * the resulting value ``V``.

    It also returns per-grid-point outputs used later for aggregation.

    Parameters
    ----------
    V_p : ndarray
        Next-period value function on ``x_grid``.
    Q : ndarray
        Transition matrix used for non-adjusters when ``compute_Q`` is False.
    x_grid : ndarray
        State grid. Assumed increasing (TODO confirm with callers).
    r : float
        Continuation values are discounted by ``1 / (1 + r)``.
    mu, sig, tau, distribution :
        Passed to ``transition_matrix`` (with ``mu`` negated) when
        ``compute_Q`` is True. ``mu`` also enters ``p`` and ``p_star``.
    la : float
        Floor on the adjustment probability: ``q = la + (1 - la) * ...``.
    mu_k, sig_k : float
        Parameters of the menu-cost draw. The formulas for ``q`` and ``EK``
        are consistent with a lognormal(mu_k, sig_k) cost.
    gap : float
        Offset of the state inside the profit function.
    discount, shift : float
        Flow profit is scaled by ``exp(discount + shift)``. The adjustment
        gain is scaled by ``exp(-discount)`` and the expected menu cost by
        ``exp(discount)``.
    elast : float or None
        None selects the quadratic loss ``-(x - gap)**2``. Otherwise the
        exponential profit formula below, with elasticity ``elast``.
    compute_Q : bool
        If True, rebuild the non-adjuster transition matrix instead of using ``Q``.

    Returns
    -------
    V : ndarray
        Today's value function.
    q : ndarray
        Adjustment probability at each grid point.
    x_max : float
        Reset point.
    Q_adj : ndarray
        Transition row of an adjuster: the row of ``Q_not`` linearly
        interpolated at ``x_max``.
    ind_adj : list of int
        The two grid indices bracketing ``x_max``.
    Q_not : ndarray
        Transition matrix of non-adjusters.
    output : dict
        Per-grid-point outputs ('p', 'p_star', 'p_reset', 'pi', 'freq',
        'profit', 'p_lag').

    Notes
    -----
    NOTE(review): the argmax of ``V_not`` must be an interior grid point.
    At index 0, ``i_max - 1`` wraps to the last element. At the last
    index, ``i_max + 1`` raises IndexError. Likewise
    ``np.where(x_grid > x_max)[0][0]`` fails if ``x_max`` is not below
    ``x_grid[-1]``.
    """
    if compute_Q:
        # transition matrix for non adjusters
        Q_not = transition_matrix(-mu, sig, distribution, tau, x_grid)
    else:
        Q_not = Q

    # Flow profit as a function of the state.
    if elast is None:
        profit = -(x_grid - gap) ** 2
    else:
        profit = ((elast / (elast - 1)) ** (1 - elast)) * (np.exp((1 - elast) * (x_grid - gap)) - ((elast - 1) / elast) * np.exp(-elast * (x_grid - gap)))

    # Value of not adjusting: scaled flow profit plus discounted expected continuation.
    V_not = np.exp(discount + shift) * profit + (1 / (1 + r)) * Q_not @ V_p

    i_max = np.argmax(V_not)

    # parabolic interpolation so the maximum can fall between grid points
    x_max, V_adj = parabolic_max(x_grid[i_max - 1], x_grid[i_max], x_grid[i_max + 1],
                                 V_not[i_max - 1], V_not[i_max], V_not[i_max + 1])

    # interpolate to find transition for adjusters
    # (adjusters move to x_max, then follow the non-adjuster transition from there)
    i1 = np.where(x_grid > x_max)[0][0]
    i0 = i1 - 1
    w0 = (x_grid[i1] - x_max) / (x_grid[i1] - x_grid[i0])
    Q_adj = w0 * Q_not[i0, :] + (1 - w0) * Q_not[i1, :]

    ind_adj = [i0, i1]

    # Adjustment probability
    # The gain V_adj - V_not is floored at 1e-6 so that its log is finite.
    diff = np.exp(-discount) * np.maximum(V_adj - V_not, 1e-6)
    diff = np.log(diff)
    q = la + (1 - la) * stats.norm.cdf((diff - mu_k) / sig_k)

    # Expected menu cost payment
    # This is the partial expectation E[K; log K < diff] of a lognormal cost.
    EK = np.exp(discount) * (1 - la) * (np.exp(mu_k + 0.5 * sig_k ** 2) * stats.norm.cdf((diff - mu_k - sig_k ** 2) / sig_k))

    V = q * V_adj + (1 - q) * V_not - EK

    # Per-grid-point outputs, averaging adjusters (at x_max) and non-adjusters (at x_grid).
    if elast is None:
        p = (1 - q) * x_grid + q * x_max - mu
        p_star = 0 * x_grid
        profit = (1 - q) * (-(x_grid - gap) ** 2) + q * (-(x_max - gap) ** 2)
    else:
        p = (1 - q) * np.exp((1 - elast) * (x_grid - mu)) + q * np.exp((1 - elast) * (x_max - mu))
        p_star = (1 - q) * np.exp(-elast * (x_grid - mu)) + q * np.exp(-elast * (x_max - mu))
        profit = ((elast / (elast - 1)) ** (1 - elast)) * (
            (1 - q) * (np.exp((1 - elast) * (x_grid - gap)) - ((elast - 1) / elast) * np.exp(-elast * (x_grid - gap))) +
            q * (np.exp((1 - elast) * (x_max - gap)) - ((elast - 1) / elast) * np.exp(-elast * (x_max - gap)))
        )

    # 'pi' is the expected change of the state: (reset point - current state) * q.
    output = {'p': p,
              'p_star': p_star,
              'p_reset': x_max*np.ones_like(x_grid),
              'pi': (x_max - x_grid)*q,
              'freq': q,
              'profit': profit,
              'p_lag': x_grid
              }

    return V, q, x_max, Q_adj, ind_adj, Q_not, output


def pol_ss(parameters, x_grid, V_seed=None, iterate_on_policy=False, tol=1e-6, maxit = 100000):
    """Solve for the steady-state policy by iterating ``backward_iteration`` to a fixed point.

    Iteration stops once the value function changes by less than ``tol``
    (sup norm). When ``iterate_on_policy`` is set, it also stops once the
    adjustment probabilities change by less than ``tol``.

    Returns
    -------
    V, Pi, q, x_max, ind_adj, Q, output_ss
        ``Pi`` is the full transition matrix. Each row mixes the adjuster row
        ``Q_adj`` (with probability ``q``) with the non-adjuster row of
        ``Q_not`` (with probability ``1 - q``). ``Q`` is the non-adjuster
        transition matrix.

    Raises
    ------
    ValueError
        If no stopping criterion is met within ``maxit`` iterations.
    """
    V_prev = np.zeros_like(x_grid) if V_seed is None else V_seed
    q_prev = np.zeros_like(V_prev)

    mu, sig, tau = parameters['mu'], parameters['sig'], parameters['tau']
    distribution = parameters['distribution']
    Q = transition_matrix(-mu, sig, distribution, tau, x_grid)

    converged = False
    for _ in range(maxit):
        V, q, x_max, Q_adj, ind_adj, Q_not, output_ss = backward_iteration(V_prev, Q, x_grid, **parameters, compute_Q=False)
        converged = (np.max(np.abs(V - V_prev)) < tol
                     or (iterate_on_policy and np.max(np.abs(q - q_prev)) < tol))
        if converged:
            break
        V_prev, q_prev = V, q

    if not converged:
        raise ValueError(f'No convergence after {maxit} backward iterations!')

    stay = 1 - q[:, np.newaxis]
    Pi = q[:, np.newaxis] * Q_adj[np.newaxis, :] + stay * Q_not

    return V, Pi, q, x_max, ind_adj, Q, output_ss


# finds steady state distribution
def dist_ss(Pi, x_grid, g_seed=None, tol=1e-12, maxit=100000):
    """Find the stationary distribution of ``Pi`` by power iteration.

    Parameters
    ----------
    Pi : ndarray, shape (nx, nx)
        Transition matrix; the distribution is updated as ``g = Pi.T @ g``.
    x_grid : ndarray
        State grid. Only its shape is used, to build the default uniform seed.
    g_seed : ndarray, optional
        Initial guess. It is copied and rescaled to sum to one.
    tol : float
        Stop once the sup-norm change between two iterates is below ``tol``.
    maxit : int
        Maximum number of iterations.

    Returns
    -------
    ndarray
        Distribution ``g`` with ``g ~= Pi.T @ g``. If ``maxit`` is reached
        first, the last iterate is returned and a ``RuntimeWarning`` is issued.
    """
    import warnings

    if g_seed is None:
        g_p = np.ones_like(x_grid, dtype=float)
    else:
        g_p = np.array(g_seed, dtype=float)
    # An unnormalised seed would otherwise yield an unnormalised "distribution".
    g_p = g_p / np.sum(g_p)

    g = g_p  # returned as-is when maxit == 0
    for _ in range(maxit):
        g = Pi.T @ g_p
        # The O(nx) sup-norm test is negligible next to the O(nx^2) product,
        # so test every iteration. The old `it % 10` guard skipped every 10th
        # iteration instead of testing only every 10th.
        if np.max(np.abs(g - g_p)) < tol:
            return g
        g_p = g

    warnings.warn(f'dist_ss: no convergence after {maxit} iterations', RuntimeWarning)
    return g


# finds steady state
def steady_state(parameters, nx, iterate_on_policy=False, V_seed=None, g_seed=None, x_grid=None, grid_bounds=3.5):
    """Compute the stationary equilibrium: policy, distribution and moments.

    If ``x_grid`` is None, a grid of ``nx`` points is built with
    ``get_grid(parameters, nx, bounds=grid_bounds)``.

    Returns
    -------
    dict
        Keys 'V', 'Pi', 'q', 'x_max', 'stats', 'g', 'Q', 'x_grid' and
        'output'. When ``parameters['elast']`` is not None, the dict also
        holds the aggregates 'P', 'P_star', 'Delta' and 'Profits'.
    """
    grid = get_grid(parameters, nx, bounds=grid_bounds) if x_grid is None else x_grid

    V, Pi, q, x_max, ind_adj, Q, output = pol_ss(parameters, grid, V_seed=V_seed, iterate_on_policy=iterate_on_policy)
    g = dist_ss(Pi, grid, g_seed=g_seed)
    # Named `moments` so the module-level `scipy.stats` is not shadowed.
    moments = statistics(q, x_max, ind_adj, g, grid, Q)

    ss = dict(V=V, Pi=Pi, q=q, x_max=x_max, stats=moments, g=g, Q=Q, x_grid=grid, output=output)

    elast = parameters['elast']
    if elast is not None:
        ss['P'] = (g @ output['p']) ** (1 / (1 - elast))
        ss['P_star'] = (g @ output['p_star']) ** (-1 / elast)
        ss['Delta'] = (ss['P_star'] / ss['P']) ** elast
        ss['Profits'] = g @ output['profit']

    return ss


def backward_iteration_transition(parameters, V_ss, Pi_ss, Q_ss, output_ss, x_grid, T, h, shock, output_list):
    """Finite-difference policy responses to a parameter shock at horizons 0..T-1.

    Iterates backward from the steady-state value ``V_ss``. The parameter
    ``shock`` is raised by ``h`` only in the first step and is at its
    steady-state value afterwards. Step ``t`` therefore gives today's
    response to a shock ``t`` periods ahead. Only the shocked step rebuilds
    the non-adjuster transition matrix; later steps reuse ``Q_ss``.

    The caller's ``parameters`` dict is left untouched.

    Returns
    -------
    dy : list of dict
        ``dy[t][k] = (output_t[k] - output_ss[k]) / h`` for each ``k`` in ``output_list``.
    dPi : list of ndarray
        ``dPi[t] = (Pi_t - Pi_ss) / h``.
    """
    # Work on a private copy. The original mutated the caller's dict and, for
    # T <= 1, never restored the shocked entry.
    params = parameters.copy()
    ss_value = params[shock]

    dPi, dy = [], []
    V_p = V_ss
    for t in range(T):
        shocked = t == 0
        params[shock] = ss_value + h if shocked else ss_value

        V, q, _, Q_adj, _, Q_not, output = backward_iteration(V_p, Q_ss, x_grid, **params, compute_Q=shocked)
        Pi = q[:, np.newaxis] * Q_adj[np.newaxis, :] + (1 - q[:, np.newaxis]) * Q_not

        dPi.append((Pi - Pi_ss) / h)
        dy.append({k: (output[k] - output_ss[k]) / h for k in output_list})

        V_p = V

    return dy, dPi


def forward_iteration_transition(Pi_ss, g_ss, output_ss, dy, dPi, T, output_list):
    """Assemble the "fake news" matrices F from the backward-pass derivatives.

    For each output ``k``:

    * ``F[k][0, s] = dy[s][k] @ g_ss`` is the direct effect;
    * for ``t >= 1``, ``F[k][t, s] = (Pi_ss**(t-1) @ output_ss[k]) @ (dPi[s].T @ g_ss)``
      is the effect through the shifted distribution.

    Returns
    -------
    dict
        Maps each output name to a (T, T) array.
    """
    F = {k: np.zeros((T, T)) for k in output_list}
    if T == 0:
        return F

    # U[s] = dPi[s].T @ g_ss does not depend on the output, so compute it once.
    U = np.array([dPi[s].T @ g_ss for s in range(T)])  # shape (T, nx)

    for k in output_list:
        F_k = F[k]
        F_k[0, :] = [dy[s][k] @ g_ss for s in range(T)]
        v = output_ss[k]  # Pi_ss**(t-1) @ output_ss[k], advanced once per row
        for t in range(1, T):
            F_k[t, :] = U @ v
            v = Pi_ss @ v

    return F


# taken from sequence space codes
def J_from_F(F):
    """Turn a fake-news matrix into a Jacobian by cumulating along diagonals.

    ``J[t, s] = F[t, s] + J[t-1, s-1]``. The first row and first column of
    ``J`` equal those of ``F``. ``F`` is not modified.
    (Adapted from the sequence-space Jacobian codes.)
    """
    J = F.copy()
    n_cols = J.shape[1]
    for col in range(1, n_cols):
        # Column col-1 is already final, so each column builds on the updated one.
        J[1:, col] = J[1:, col] + J[:-1, col - 1]
    return J


def compute_jacobian(parameters, nx, T, h, shock_list, output_list):
    """Compute sequence-space Jacobians of outputs with respect to parameter shocks.

    Solves the steady state, then for each shock runs the finite-difference
    backward pass, builds the fake-news matrix and cumulates it into a
    Jacobian.

    When ``parameters['elast']`` is not None, the Jacobians of 'p' and
    'p_star' are rescaled to log changes of the aggregates
    ``P = (g @ p) ** (1 / (1 - elast))`` and ``P* = (g @ p_star) ** (-1 / elast)``.

    Returns
    -------
    J : dict
        ``J[output][shock]`` is a (T, T) array.
    ss : dict
        The steady state.
    """
    ss = steady_state(parameters, nx)

    V, Pi, Q, output, x_grid, g = ss['V'], ss['Pi'], ss['Q'], ss['output'], ss['x_grid'], ss['g']

    J = {out: {} for out in output_list}

    for shock in shock_list:
        # Pass a fresh copy so a perturbation left behind by the backward pass
        # cannot leak into later shocks or into the normalisation below.
        dy, dPi = backward_iteration_transition(parameters.copy(), V, Pi, Q, output, x_grid, T, h, shock, output_list)
        F = forward_iteration_transition(Pi, g, output, dy, dPi, T, output_list)
        for out in output_list:
            J[out][shock] = J_from_F(F[out])

    elast = parameters['elast']
    if elast is not None:
        # d log P = d(g @ p) / ((1 - elast) * (g @ p))
        if 'p' in output_list:
            for shock in shock_list:
                J['p'][shock] = J['p'][shock] / ((1 - elast) * (output['p'] @ g))
        # d log P* = d(g @ p_star) / (-elast * (g @ p_star))
        if 'p_star' in output_list:
            for shock in shock_list:
                J['p_star'][shock] = J['p_star'][shock] / ((-elast) * (output['p_star'] @ g))

    return J, ss


def compute_eigenvalues(xbar, la, sig, tau, nx=300):
    """Return the eigenvalues, in descending order, of a symmetric Toeplitz shock operator.

    On a uniform grid of ``nx`` points over ``[-xbar, xbar]`` with spacing
    ``dx``, ``phi[k]`` is the probability that a N(0, sig**2) innovation
    falls in ``[(k - 0.5) dx, (k + 0.5) dx]``. The matrix is
    ``tau * (1 - la) * toeplitz(phi) + (1 - tau) * I``.
    """
    grid = np.linspace(-xbar, xbar, nx)
    step = grid[1] - grid[0]
    edges = (np.arange(-1, nx) + 0.5) * step
    jump_probs = np.diff(stats.norm.cdf(edges, scale=sig))

    shock_part = (1 - la) * scipy.linalg.toeplitz(jump_probs)
    operator = tau * shock_part + (1 - tau) * np.eye(nx)

    eigenvalues = np.linalg.eigvalsh(operator)
    return np.sort(eigenvalues)[::-1]


def compute_td_equivalence(parameters, nx, T, ss=None, h=1e-4):
    """Approximate the model Jacobian by a weighted mix of two time-dependent Jacobians.

    Starting from the steady state (computed if ``ss`` is None), two
    perturbations are pushed forward ``T`` periods, and the change in mean
    ``x`` is recorded each period (scaled by ``1/h``):

    * intensive: the reset point ``x_max`` is raised by ``h``;
    * extensive: the adjustment-probability schedule ``q`` is shifted by
      ``h`` along the grid.

    Each cumulative response is divided by its last value, which becomes
    its weight, and then first-differenced into ``f``. ``f`` is passed to
    ``td.general_td_jacobian``. The two Jacobians are combined with weights
    rescaled to sum to one.

    Requires ``parameters['mu'] == 0`` and ``parameters['elast'] is None``.

    Returns
    -------
    J, ss, J_ext, J_int, f_ext, f_int, w_ext, w_int
    """
    # check parameter values required for equivalence to hold
    assert parameters['mu'] == 0
    assert parameters['elast'] is None

    beta = 1 / (1 + parameters['r'])

    if ss is None:
        ss = steady_state(parameters, nx)

    q_ss, x_grid, Q = ss['q'], ss['x_grid'], ss['Q']
    g_ss, x_max_ss = ss['g'], ss['x_max']

    def adjuster_row(x_reset):
        # Transition row of an adjuster: Q linearly interpolated at x_reset.
        upper = np.where(x_grid > x_reset)[0][0]
        lower = upper - 1
        w_lower = (x_grid[upper] - x_reset) / (x_grid[upper] - x_grid[lower])
        return w_lower * Q[lower, :] + (1 - w_lower) * Q[upper, :]

    def normalised_irf(q, Q_adj):
        # Return the differenced response of mean x, scaled by its terminal
        # value, together with that terminal value.
        Pi = q[:, np.newaxis] * Q_adj[np.newaxis, :] + (1 - q[:, np.newaxis]) * Q
        cum = np.zeros(T)
        g = g_ss.copy()
        for t in range(T):
            g = Pi.T @ g
            cum[t] = (g - g_ss) @ x_grid / h
        scale = cum[-1]
        cum /= scale
        cum[1:] = cum[1:] - cum[:-1]
        return cum, scale

    # intensive margin: reset point moves up by h
    f_int, w_int = normalised_irf(q_ss, adjuster_row(x_max_ss + h))

    # extensive margin: adjustment hazard shifted by h
    q_shifted = np.interp(x_grid, x_grid + h, q_ss)
    f_ext, w_ext = normalised_irf(q_shifted, adjuster_row(x_max_ss))

    total = w_ext + w_int
    w_ext = w_ext / total
    w_int = w_int / total

    J_ext = td.general_td_jacobian(f=f_ext, beta=beta)
    J_int = td.general_td_jacobian(f=f_int, beta=beta)
    J = w_ext * J_ext + w_int * J_int

    return J, ss, J_ext, J_int, f_ext, f_int, w_ext, w_int


def compute_td_equivalence_trend(parameters, nx, T, ss=None, h=1e-4):
    """Time-dependent approximation with the extensive margin split by the sign of x.

    Three perturbations of the steady state are pushed forward ``T`` periods,
    and the change in mean ``x`` is recorded each period (scaled by ``1/h``):

    * intensive: the reset point ``x_max`` is raised by ``h``;
    * lower: ``q`` is shifted by ``h`` only on grid points with ``x < 0``;
    * upper: ``q`` is shifted by ``h`` only on grid points with ``x > 0``.

    Each cumulative response is divided by its last value, which becomes
    its weight, and then first-differenced into ``f``. ``f`` is passed to
    ``td.general_td_jacobian``. The three Jacobians are combined with
    weights rescaled to sum to one.

    Requires ``parameters['elast'] is None``.

    Returns
    -------
    J, ss, J_int, J_lower, J_upper, f_int, f_lower, f_upper, w_int, w_lower, w_upper
    """
    # check parameter values required for equivalence to hold
    assert parameters['elast'] is None

    beta = 1 / (1 + parameters['r'])

    if ss is None:
        ss = steady_state(parameters, nx)

    q_ss, x_grid, Q = ss['q'], ss['x_grid'], ss['Q']
    g_ss, x_max_ss = ss['g'], ss['x_max']

    def adjuster_row(x_reset):
        # Transition row of an adjuster: Q linearly interpolated at x_reset.
        upper = np.where(x_grid > x_reset)[0][0]
        lower = upper - 1
        w_lower = (x_grid[upper] - x_reset) / (x_grid[upper] - x_grid[lower])
        return w_lower * Q[lower, :] + (1 - w_lower) * Q[upper, :]

    def normalised_irf(q, Q_adj):
        # Return the differenced response of mean x, scaled by its terminal
        # value, together with that terminal value.
        Pi = q[:, np.newaxis] * Q_adj[np.newaxis, :] + (1 - q[:, np.newaxis]) * Q
        cum = np.zeros(T)
        g = g_ss.copy()
        for t in range(T):
            g = Pi.T @ g
            cum[t] = (g - g_ss) @ x_grid / h
        scale = cum[-1]
        cum /= scale
        cum[1:] = cum[1:] - cum[:-1]
        return cum, scale

    def shift_q_on(mask):
        # Shift the hazard by h, but only on the grid points selected by mask.
        q_new = q_ss.copy()
        q_new[mask] = np.interp(x_grid[mask], x_grid[mask] + h, q_ss[mask])
        return q_new

    # intensive margin: reset point moves up by h
    f_int, w_int = normalised_irf(q_ss, adjuster_row(x_max_ss + h))

    # extensive margins, lower (x < 0) and upper (x > 0) parts of the hazard
    Q_adj_ss = adjuster_row(x_max_ss)
    f_lower, w_lower = normalised_irf(shift_q_on(x_grid < 0), Q_adj_ss)
    f_upper, w_upper = normalised_irf(shift_q_on(x_grid > 0), Q_adj_ss)

    total = w_int + w_lower + w_upper
    w_int = w_int / total
    w_lower = w_lower / total
    w_upper = w_upper / total

    J_int = td.general_td_jacobian(f=f_int, beta=beta)
    J_lower = td.general_td_jacobian(f=f_lower, beta=beta)
    J_upper = td.general_td_jacobian(f=f_upper, beta=beta)
    J = w_int * J_int + w_lower * J_lower + w_upper * J_upper

    return J, ss, J_int, J_lower, J_upper, f_int, f_lower, f_upper, w_int, w_lower, w_upper


def compute_Es(parameters, nx, ss=None, tmax=10):
    """Compute survival-weighted expectation objects of the no-adjustment process.

    With ``stay = 1 - q``, the recursions for t = 1..tmax-1 are:

    * ``E[0] = x_grid``, ``E[t] = Q @ (stay * E[t-1])``;
    * ``P[0] = 1``, ``P[t] = Q @ (stay * P[t-1])``.

    Each is evaluated at ``xbar``, the upper-half grid point where ``q``
    crosses 0.5 (by interpolation). ``E`` is also differentiated at 0, and
    ``P`` is evaluated at 0.

    If ``ss`` is given it must provide 'q', 'x_grid' and 'Q'. The grid size
    is taken from ``ss['x_grid']``; ``nx`` is only used to build ``ss`` when
    it is missing.

    Returns
    -------
    E, E_prime, E_xbar, P, P_zero, P_xbar
    """
    if ss is None:
        ss = steady_state(parameters, nx)

    q = ss['q']
    x_grid = ss['x_grid']
    Q = ss['Q']
    # Use the size of the grid actually in hand; a supplied ss may not match nx.
    n = len(x_grid)
    stay = 1 - q

    # find xbar and index of x=0
    # (np.interp assumes q is increasing on the upper half of the grid; TODO confirm)
    half = n // 2
    xbar = np.interp(0.5, q[half:], x_grid[half:])
    i1 = np.where(x_grid > 0)[0][0]
    i0 = i1 - 1
    assert np.isclose(x_grid[i1], -x_grid[i0])  # this can be relaxed if another numerical derivative scheme is used

    E, P = [], []
    E_xbar = np.zeros(tmax)
    E_prime = np.zeros(tmax)
    P_zero = np.zeros(tmax)
    P_xbar = np.zeros(tmax)

    E_t = x_grid
    P_t = np.ones(n)
    for t in range(tmax):
        if t > 0:
            E_t = Q @ (stay * E_t)
            # The column sums of G_t = diag(stay) Q.T G_{t-1} (G_0 = I) equal
            # (Q diag(stay))**t @ 1. That is this O(n^2) recursion, instead
            # of the O(n^3) matrix product.
            P_t = Q @ (stay * P_t)
        E_xbar[t] = np.interp(xbar, x_grid, E_t)
        E_prime[t] = (E_t[i1] - E_t[i0]) / (x_grid[i1] - x_grid[i0])
        P_xbar[t] = np.interp(xbar, x_grid, P_t)
        P_zero[t] = np.interp(0, x_grid, P_t)
        E.append(E_t)
        P.append(P_t)

    return E, E_prime, E_xbar, P, P_zero, P_xbar


def compute_PC(J, parameters=None, ss=None, h=1e-4, T_extend=1000, permanent_shock_to_normalize=False, cut_back=False, return_nominal=False):
    """Convert a price Jacobian ``J`` into a Phillips-curve Jacobian.

    The target is the solution ``J_real`` of ``(I - J) @ J_real = J``. The
    matrix ``I - J`` is ill-conditioned, so both sides are premultiplied by
    ``P = J_calvo @ inv(J)``, where ``J_calvo`` is the cumulated
    ``td.calvo_PC(1, beta, T_extend)``. The system actually solved is
    ``(P @ (I - J)) @ J_pc = J_calvo``. The result is then first-differenced
    over rows.

    ``J`` is first extended to ``T_extend`` (at least ``len(J) + 100`` when
    ``len(J) > T_extend``) to improve precision. If
    ``permanent_shock_to_normalize`` is set, the rows are rescaled so that
    they sum to ``permanent_gap_shock``; this requires ``ss``. ``beta`` is
    ``1 / (1 + parameters['r'])``, or 0.99 when ``parameters`` is None.

    Returns ``J_pc``, or ``(J_pc, J_extended)`` if ``return_nominal``.
    With ``cut_back`` both are truncated to ``len(J) x len(J)``.
    """
    T = len(J)
    if T > T_extend:
        T_extend = T + 100

    if permanent_shock_to_normalize:
        # Normalise by the permanent shock so far-out rows sum to the right value.
        grid = ss['x_grid']
        target = permanent_gap_shock(parameters, len(grid), T, h, grid, ss['Pi'], ss['g'])
        row_sums = np.sum(J, axis=1)
        J = J * target[:, np.newaxis] / row_sums[:, np.newaxis]

    J = extend_matrix(J, T_extend)

    beta = 0.99 if parameters is None else 1 / (1 + parameters['r'])
    J_calvo = np.cumsum(td.calvo_PC(1, beta, T_extend), axis=0)

    # Preconditioner: P @ J == J_calvo, so P @ (I - J) is well-conditioned.
    precond = np.linalg.solve(J.T, J_calvo.T).T
    lhs = precond @ (np.eye(T_extend) - J)
    J_pc = np.linalg.solve(lhs, J_calvo)
    # Undo the row cumulation.
    J_pc[1:, :] = np.diff(J_pc, axis=0)

    if cut_back:
        J_pc, J = J_pc[:T, :T], J[:T, :T]

    return (J_pc, J) if return_nominal else J_pc


def mit_shock(parameters, nx, shock_paths, output_list, ss=None, return_only_irf=True, initial_distribution=None):
    """Perfect-foresight transition for given parameter paths.

    ``shock_paths`` maps parameter names to sequences of common length T,
    giving each parameter's value period by period.

    Policies are solved backward from a terminal value function:

    * if the final path values differ from ``parameters``, it is the steady
      state under those final values, computed on the same grid;
    * otherwise it is ``ss['V']``.

    The distribution then starts from ``initial_distribution`` (default
    ``ss['g']``) and is pushed forward to aggregate each output in
    ``output_list``. The caller's ``parameters`` dict is not modified.

    Returns
    -------
    irf : dict
        Maps each output name to a length-T array. This is the only return
        value unless ``return_only_irf`` is False.
    irf, Pi_transition, q_transition, x_max_transition, g_transition
        Returned instead when ``return_only_irf`` is False.
    """
    params = parameters.copy()
    if ss is None:
        ss = steady_state(params, nx)

    V_ss, x_grid, g_ss, Q_ss = ss['V'], ss['x_grid'], ss['g'], ss['Q']

    T = len(list(shock_paths.values())[0])

    # Shocks to the innovation process require rebuilding the transition matrix.
    compute_Q = bool({'mu', 'sig', 'tau'} & set(shock_paths.keys()))

    # if there is a permanent change in parameter values, compute the terminal V_ss
    permanent = not np.all([np.isclose(params[k], shock_paths[k][-1]) for k in shock_paths.keys()])
    if permanent:
        params_final = params.copy()
        params_final.update({k: shock_paths[k][-1] for k in shock_paths.keys()})
        V_next = steady_state(params_final, nx, x_grid=x_grid)['V']
    else:
        V_next = V_ss

    Pi_transition = [None] * T
    q_transition = [None] * T
    x_max_transition = [None] * T
    output_transition = [None] * T

    for t in reversed(range(T)):
        for name, path in shock_paths.items():
            params[name] = path[t]
        V_next, q, x_max, Q_adj, _, Q_not, output = backward_iteration(V_next, Q_ss, x_grid, **params, compute_Q=compute_Q)
        Pi_transition[t] = q[:, np.newaxis] * Q_adj[np.newaxis, :] + (1 - q[:, np.newaxis]) * Q_not
        q_transition[t] = q
        x_max_transition[t] = x_max
        output_transition[t] = output

    g = g_ss if initial_distribution is None else initial_distribution

    irf = {out: np.empty(T) for out in output_list}
    g_transition = []
    for t in range(T):
        for out in output_list:
            irf[out][t] = np.dot(g, output_transition[t][out])
        g_transition.append(g)
        g = Pi_transition[t].T @ g

    if return_only_irf is False:
        return irf, Pi_transition, q_transition, x_max_transition, g_transition

    return irf


def permanent_gap_shock(parameters, nx, T, h, x_grid=None, Pi_ss=None, g_ss=None):
    """Normalised response of mean x to a one-off shift of the distribution by ``h``.

    The steady-state distribution is moved ``h`` toward lower x (via
    ``np.interp`` on ``x_grid - h``) and renormalised. It is then pushed
    forward ``T`` periods with ``Pi_ss``. With ``p[t]`` the change in mean x
    after ``t + 1`` periods, the function returns ``(p + h) / h``.

    If any of ``x_grid``, ``Pi_ss`` or ``g_ss`` is missing, all three are
    taken from ``steady_state(parameters, nx)``.
    """
    if x_grid is None or g_ss is None or Pi_ss is None:
        ss = steady_state(parameters, nx)
        x_grid, Pi_ss, g_ss = ss['x_grid'], ss['Pi'], ss['g']

    shifted = np.interp(x_grid, x_grid - h, g_ss)
    g = shifted / np.sum(shifted)

    responses = []
    for _ in range(T):
        g = Pi_ss.T @ g
        responses.append((g - g_ss) @ x_grid)
    p = np.array(responses, dtype=float)

    return (p + h) / h


def extend_matrix(X, T_extend):
    T = len(X)
    t_middle = np.int(T / 2)
    middle_column = X[:, t_middle]
    X_large = np.zeros((T_extend, T_extend))
    X_large[:T, :T] = X
    for i in range(T_extend - t_middle):
        t_diag = t_middle + i
        t0 = t_diag - t_middle
        t1 = min(t_diag - t_middle + T, T_extend)
        X_large[t0:t1, t_diag] = middle_column[0:t1 - t0]
    return X_large


def get_grid(parameters, nx, bounds=3.5, max_expansions=100):
    """Build an ``nx``-point grid that covers the inaction region with a margin.

    A trial grid on [-1, 1] is widened by 25% at a time until, under the
    policy solved on it, both of these hold:

    * some point left of ``x_max`` has q < 0.99;
    * some point right of ``x_max`` has q > 0.99.

    The final grid then runs between two points on the trial grid. The
    left end is the point just before the first left-side point with
    q < 0.99. The right end is the first right-side point with q > 0.99.
    Each side is padded by ``bounds * sig``, and the whole grid is shifted
    by the drift ``-mu``.

    Raises
    ------
    ValueError
        If no suitable trial grid is found within ``max_expansions`` widenings.
    """
    drift = -parameters['mu']
    sig = parameters['sig']

    xmin, xmax = -1.0, 1.0
    for _ in range(max_expansions + 1):
        x_grid = np.linspace(xmin, xmax, nx)
        _, _, q, x_max, _, _, _ = pol_ss(parameters, x_grid, iterate_on_policy=True, tol=1e-9)

        left = np.where((q < 0.99) & (x_grid < x_max))[0]
        right = np.where((q > 0.99) & (x_grid > x_max))[0]
        if left.size > 0 and right.size > 0:
            break
        xmin *= 1.25
        xmax *= 1.25
    else:
        raise ValueError(f'get_grid: no suitable grid after {max_expansions} expansions')

    # Clamp at 0: `left[0] - 1` would otherwise wrap to the top of the grid.
    imin = max(left[0] - 1, 0)
    imax = right[0]

    xmin = x_grid[imin] - bounds * sig + drift
    xmax = x_grid[imax] + bounds * sig + drift

    return np.linspace(xmin, xmax, nx)


def statistics(q, x_max, ind_adj, g, x_grid, Q, t_max=50):
    """Steady-state moments of price changes and the survival/hazard functions.

    A price change from state x is ``x_max - x``. Changes are weighted by
    ``q * g``, excluding the two grid points ``ind_adj`` that bracket the
    reset point.

    Survival starts from a unit mass at ``x_max``, split between the two
    bracketing grid points. Each period that mass is moved with ``Q.T`` and
    multiplied by ``1 - q``.

    Returns
    -------
    dict
        Keys 'freq', 'mean_dp', 'mean_abs_dp', 'med_abs_dp', 'std_dp',
        'kurt_dp', 'survival' (length ``t_max``) and 'hazards'
        (length ``t_max - 1``).
    """
    freq = np.sum(q * g)
    dx = x_max - x_grid
    abs_dx = np.abs(dx)

    weight = q * g / freq
    # adjusted prices are sent to ind_adj, so it makes no sense to count them in adjustments, makes a difference if tau < 1
    weight[ind_adj] = 0
    weight = weight / np.sum(weight)

    mean_dx = dx @ weight
    mean_abs_dx = abs_dx @ weight

    std_dx = np.sqrt(((dx - mean_dx) ** 2) @ weight)
    kurt_dx = (((dx - mean_dx) ** 4) @ weight) / (std_dx ** 4)

    # Weighted median of |dp|: average the first order statistic whose
    # cumulative weight exceeds 0.5 with its predecessor.
    order = np.argsort(abs_dx)
    sorted_abs_dx = abs_dx[order]
    cum_weight = np.cumsum(weight[order])
    i_med = np.where(cum_weight > 0.5)[0][0]
    # At i_med == 0 there is no predecessor; the old index -1 wrapped to the largest |dp|.
    i_prev = max(i_med - 1, 0)
    med_abs_dx = 0.5 * sorted_abs_dx[i_med] + 0.5 * sorted_abs_dx[i_prev]

    i1 = np.where(x_grid > x_max)[0][0]
    i0 = i1 - 1
    w0 = (x_grid[i1] - x_max) / (x_grid[i1] - x_grid[i0])

    g0 = np.zeros(x_grid.shape)
    g0[i0] = w0
    g0[i1] = 1 - w0
    survival = np.zeros(t_max)
    survival[0] = 1
    for t in range(1, t_max):
        g0 = (1 - q) * (Q.T @ g0)
        survival[t] = np.sum(g0)
    hazards = 1 - survival[1:] / survival[:-1]

    # Named `moments` so the module-level `scipy.stats` is not shadowed.
    moments = {'freq': freq,
               'mean_dp': mean_dx,
               'mean_abs_dp': mean_abs_dx,
               'med_abs_dp': med_abs_dx,
               'std_dp': std_dx,
               'kurt_dp': kurt_dx,
               'survival': survival,
               'hazards': hazards,
               }

    return moments


def calibrate_model(targs, parameters, initial_guess, nx=500, x_grid=None):
    """Calibrate parameters so that steady-state moments hit their targets.

    Parameters
    ----------
    targs : dict
        Moment name -> target value. Names are keys of ``ss['stats']``.
    parameters : dict
        Full parameter dict. It is copied, never modified.
    initial_guess : dict
        Parameter name -> starting value for each calibrated parameter.
    nx, x_grid :
        Passed through to ``steady_state``.

    Returns
    -------
    par_out : dict
        ``parameters`` updated with the calibrated values.
    ss : dict
        Steady state at ``par_out``. A message is printed if any moment
        misses its target by 5e-4 or more.
    """
    names = list(initial_guess.keys())
    x0 = np.array(list(initial_guess.values()))

    def with_values(values):
        par = parameters.copy()
        par.update(dict(zip(names, values)))
        return par

    def residuals(values):
        moments = steady_state(with_values(values), nx, x_grid=x_grid)['stats']
        return np.array([moments[k] - targs[k] for k in targs.keys()])

    results = opt.least_squares(residuals, x0=x0)
    par_out = with_values(results.x)
    ss = steady_state(par_out, nx, x_grid=x_grid)

    moments = ss['stats']
    if not np.all([np.abs(moments[k] - targs[k]) < 5e-4 for k in targs.keys()]):
        print('Calibration unsuccessful.')

    return par_out, ss


def jacobian_distance(theta, M_target, beta, T, Price, Nominal, Absolute, last_iteration=False):
    """Spectral-norm distance between ``M_target`` and a Calvo-model Jacobian.

    This is the objective function of the Jacobian approximation.

    Parameters
    ----------
    theta : float
        Calvo parameter.
    M_target : ndarray, shape (T, T)
        Jacobian to approximate.
    beta : float
        Discount factor.
    T : int
        Jacobian size.
    Price, Nominal : bool
        Select the form of the approximating Jacobian.
        * ``Nominal=True``: ``td.calvo_jacobian(theta, beta, T)``; if not
          ``Price``, its columns are first-differenced.
        * ``Nominal=False``: ``td.calvo_PC(kappa, beta, T)`` with
          ``kappa = (1 - theta)(1 - beta theta) / theta``; if ``Price``, it
          is cumulated over rows.
    Absolute : bool
        True measures ``M_target - M_approx``. False measures the relative
        error ``M_approx @ inv(M_target) - I``.
    last_iteration : bool
        Also return the top right-singular vector of the error and the
        distance normalised by the norm of ``M_target``.

    Returns
    -------
    dist, M_approx, ma_coefs, dist_normalized
        ``ma_coefs`` is ``[]`` and ``dist_normalized`` is None unless
        ``last_iteration`` is set. In that case ``ma_coefs`` is the
        eigenvector of ``diff.T @ diff`` (shape (T, 1)) for its largest
        eigenvalue: the MA coefficients that maximise the approximation
        error.
    """
    if Nominal:
        M_approx = td.calvo_jacobian(theta, beta, T)
        if not Price:
            M_approx[:, 1:] = M_approx[:, 1:] - M_approx[:, :-1]
    else:
        kappa = (1 - theta) * (1 - beta * theta) / theta
        M_approx = td.calvo_PC(kappa, beta, T)
        if Price:
            M_approx = np.tril(np.ones(T)) @ M_approx

    if Absolute:
        diff = M_target - M_approx
    else:
        diff = np.linalg.solve(M_target.T, M_approx.T).T - np.eye(T)

    # eigh returns eigenvalues in ascending order, so index T-1 is the largest.
    # `subset_by_index` replaces the `eigvals=` keyword removed in SciPy 1.14.
    top = [T - 1, T - 1]

    def largest_singular_value(A):
        lam = scipy.linalg.eigh(A.T @ A, eigvals_only=True, subset_by_index=top)
        # Clip round-off negatives so sqrt cannot return NaN.
        return np.sqrt(np.maximum(lam, 0.0))[0]

    dist = largest_singular_value(diff)

    if last_iteration:
        ma_coefs = scipy.linalg.eigh(diff.T @ diff, subset_by_index=top)[1]
        dist_normalized = dist / largest_singular_value(M_target)
    else:
        ma_coefs = []
        dist_normalized = None

    return dist, M_approx, ma_coefs, dist_normalized


def approx_jacobian(M_target, beta, Price=True, Nominal=True, Absolute=True):
    """Approximate ``M_target`` with the best-fitting Calvo-model Jacobian.

    Parameters
    ----------
    M_target : ndarray, shape (T, T)
        Target Jacobian.
    beta : float or None
        Discount factor.
        * Given: only the Calvo parameter theta is fitted, by bounded scalar
          minimisation on (0, 1).
        * None: theta and beta are fitted jointly by differential evolution
          on [1e-6, 1]^2.
    Price, Nominal, Absolute : bool
        Forwarded to ``jacobian_distance``.
        * ``Price``: price-level rather than inflation Jacobian.
        * ``Nominal``: with respect to nominal rather than real marginal cost.
        * ``Absolute``: absolute rather than relative error.

    Returns
    -------
    ``(M_approx, theta, dist, ma_coefs, dist_normalized)`` when ``beta`` is
    given, otherwise ``(M_approx, theta, beta, dist, ma_coefs, dist_normalized)``.
    """
    T = len(M_target)

    if beta is None:
        res = opt.differential_evolution(lambda x: jacobian_distance(x[0], M_target, x[1], T, Price, Nominal, Absolute)[0], bounds=[(1e-6, 1), (1e-6, 1)], disp=False)
        theta, beta_fit = res.x[0], res.x[1]
        dist, M_approx, ma_coefs, dist_normalized = jacobian_distance(theta, M_target, beta_fit, T, Price, Nominal, Absolute, last_iteration=True)
        return M_approx, theta, beta_fit, dist, ma_coefs, dist_normalized

    def objective(theta_guess):
        return jacobian_distance(theta_guess, M_target, beta, T, Price, Nominal, Absolute)[0]

    res = opt.minimize_scalar(objective, bounds=(0, 1), method='Bounded')
    theta = res.x
    dist, M_approx, ma_coefs, dist_normalized = jacobian_distance(theta, M_target, beta, T, Price, Nominal, Absolute, last_iteration=True)
    return M_approx, theta, dist, ma_coefs, dist_normalized
